import os
import argparse
import numpy as np

from transformers import (
    DataCollatorForSeq2Seq,
    PreTrainedTokenizerFast,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    set_seed,
    AutoModelForSeq2SeqLM,
)
from transformers.trainer_utils import get_last_checkpoint

import dataset_utils
import deepspeed_config


def get_args():
    """Build the command-line parser for the pretraining script.

    Returns:
        The configured ``argparse.ArgumentParser``. The parser itself is
        returned (not a parsed namespace), so the caller is responsible for
        calling ``parse_args()`` on it.
    """
    cli = argparse.ArgumentParser()

    def add_flag(name):
        # Boolean switch: False unless the option is present on the command line.
        cli.add_argument(name, action="store_true")

    # Arguments are registered in a fixed order so that --help output is stable.
    cli.add_argument("--outputdir", type=str)
    cli.add_argument("--maxlen", type=int, default=512)
    cli.add_argument("--modelname", type=str)
    cli.add_argument("--tasks", nargs="+", default=[])
    cli.add_argument("--train_bsz", type=int, default=16)
    cli.add_argument("--eval_bsz", type=int, default=24)
    cli.add_argument("--grad_acc_steps", type=int, default=8)
    cli.add_argument("--weight_decay", type=float, default=0.01)
    cli.add_argument("--nepochs", type=int, default=1)
    cli.add_argument("--maxsteps", type=int)
    cli.add_argument("--warmup_steps", type=int, default=750)
    add_flag("--bf16")
    add_flag("--fp16")
    cli.add_argument("--subsample_ds", type=int, default=-1, help="subsample dataset to speed-test")
    cli.add_argument("--lr", type=float, default=5e-5)
    add_flag("--skip_train")
    cli.add_argument("--nproc", type=int, default=32)
    cli.add_argument("--clearml_name", type=str, default="test-run")
    cli.add_argument("--maskrate", type=float)
    # There is deliberately no --local_rank option: the script reads the rank
    # from the LOCAL_RANK environment variable instead.
    add_flag("--use_zero2")
    add_flag("--use_zero3")
    add_flag("--offload_params")
    cli.add_argument("--datadir", type=str)

    return cli


def init_clearml_logging(project, task_name):
    """Export ClearML project/task settings through environment variables.

    No ClearML ``Task`` object is created here; the function only sets
    ``CLEARML_PROJECT``, ``CLEARML_TASK`` and ``CLEARML_TASK_NO_REUSE`` in the
    current process environment (presumably picked up later by the ClearML
    integration -- confirm against the reporting setup).

    Args:
        project: Value stored in ``CLEARML_PROJECT``.
        task_name: Value stored in ``CLEARML_TASK``.
    """
    os.environ.update(
        {
            "CLEARML_PROJECT": project,
            "CLEARML_TASK": task_name,
            "CLEARML_TASK_NO_REUSE": "false",
        }
    )


def get_tokenizer_path(modelname):
    """Return the on-disk tokenizer location for a supported model name.

    Args:
        modelname: Either ``"codet5-small"`` or ``"codet5-large"``.

    Returns:
        The relative path of the tokenizer directory.

    Raises:
        AssertionError: If ``modelname`` is not one of the supported names.
    """
    # Both supported CodeT5 variants share the same tokenizer artifacts.
    if modelname in ("codet5-small", "codet5-large"):
        return "./artifacts/tokenizer/codet5/tokenizer"

    raise AssertionError("modelname not supported")


def get_checkpoint(outputdir):
    """Find a checkpoint to resume from, creating the output directory if needed.

    Args:
        outputdir: Directory where training checkpoints are written.

    Returns:
        The result of ``get_last_checkpoint(outputdir)`` when the directory
        already exists; ``None`` when the directory had to be created (a
        fresh run has nothing to resume from).
    """
    if not os.path.isdir(outputdir):
        # Fresh run: make sure the directory exists and report "no checkpoint".
        os.makedirs(outputdir, exist_ok=True)
        return None

    return get_last_checkpoint(outputdir)


def get_model_and_tokenizer(modelname, tokenizer_path):
    """Load the seq2seq model and its tokenizer for a supported model name.

    The tokenizer is loaded from ``tokenizer_path`` and the model's token
    embeddings are resized to the tokenizer's length.

    Args:
        modelname: Either ``"codet5-small"`` or ``"codet5-large"``.
        tokenizer_path: Location passed to
            ``PreTrainedTokenizerFast.from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        AssertionError: If ``modelname`` is not supported.
    """
    aimos_proxies = {"http": "http://proxy:8888", "https": "http://proxy:8888"}

    # Resolve the pretrained weights and any extra loading options up front so
    # that an unsupported name fails before anything is loaded.
    if modelname == "codet5-small":
        pretrained_id = "Salesforce/codet5-small"
        extra_kwargs = {}
    elif modelname == "codet5-large":
        # NOTE(review): the "large" option loads the codet5-base weights, and
        # only from local files -- confirm this mismatch is intentional.
        pretrained_id = "Salesforce/codet5-base"
        extra_kwargs = {"local_files_only": True}
    else:
        raise AssertionError("model name not supported")

    tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_path)
    model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_id, proxies=aimos_proxies, **extra_kwargs)
    # Keep the embedding matrix in sync with the tokenizer's size.
    model.resize_token_embeddings(len(tokenizer))

    return model, tokenizer


if __name__ == "__main__":
    # Script entry point: parse arguments, seed RNGs, load model/tokenizer and
    # datasets, then (unless --skip_train is given) train and save the model.

    # Local rank of this process, taken from the LOCAL_RANK environment
    # variable (0 when unset) -- presumably exported by the distributed launcher.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))

    # --- parse arguments
    # get_args() returns the parser itself, so the actual parsing happens here.
    parser = get_args()
    args = parser.parse_args()

    print("Arguments: ", args)
    print(f"Local rank: {local_rank}")

    # Disabled manual distributed initialisation, kept for reference only.
    # NOTE(review): `os.env` does not exist in the standard library; these
    # lines would need `os.environ` if they are ever re-enabled.
    # deepspeed.init_distributed(dist_backend="nccl")
    # torch.cuda.set_device(os.env["LOCAL_RANK"])
    # world_size = os.env["WORLD_SIZE"]
    # print(f"total GPUs = {world_size}")

    # --- set random seed
    # The same fixed seed is also passed to the training arguments below.
    seed = 42
    np.random.seed(seed)
    set_seed(seed)

    # --- initialize logging
    # Only exports CLEARML_* environment variables; no ClearML task is created here.
    project_name = f"structure-pretraining/pretrain-{args.modelname}"
    init_clearml_logging(project_name, args.clearml_name)

    # --- get deepspeed config
    # Built by the project-local deepspeed_config module from the parsed args
    # (presumably driven by --use_zero2 / --use_zero3 / --offload_params -- confirm there).
    ds_config = deepspeed_config.get_deepspeed_config(args)
    print("deepspeed config: ", ds_config)

    # -- load model and tokenizer
    print("Loading tokenizer and model")
    tokenizer_path = get_tokenizer_path(args.modelname)
    print(f"Tokenizer: {tokenizer_path}")
    model, tokenizer = get_model_and_tokenizer(args.modelname, tokenizer_path)
    print("Model loaded")
    # None when the output directory was just created; otherwise whatever
    # get_last_checkpoint found in it.
    last_checkpoint = get_checkpoint(args.outputdir)
    print(f"Loading from checkpoint: {last_checkpoint is not None}")

    # -- load dataset and data collator
    def tokenize(examples):
        """Tokenize the "inputs" and "outputs" columns of ``examples``.

        "inputs" is tokenized as the source text and "outputs" is passed as
        ``text_target``. Sequences are truncated to ``args.maxlen`` and the
        result is requested as padded PyTorch tensors without token type ids.
        Relies on the module-level ``tokenizer`` and ``args``.
        """
        inputs = examples["inputs"]
        outputs = examples["outputs"]
        model_inputs = tokenizer(
            inputs,
            text_target=outputs,
            truncation=True,
            max_length=args.maxlen,
            return_tensors="pt",
            return_token_type_ids=False,
            padding=True,
        )
        return model_inputs

    # NOTE(review): the meaning of the third positional argument (None) is
    # defined in dataset_utils and is not visible from this file.
    ds_train, ds_val = dataset_utils.get_training_dataset(
        args.tasks,
        args.datadir,
        None,
        args.subsample_ds,
    )

    print("Dataset loaded")

    # Register `tokenize` as the transform of both splits (presumably Hugging
    # Face `datasets` objects, where the transform is applied lazily on access
    # -- confirm against dataset_utils).
    ds_train.set_transform(tokenize)
    ds_val.set_transform(tokenize)

    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding=True, max_length=args.maxlen)

    # When --skip_train is given, everything below is skipped: no Trainer is
    # built and nothing is trained or saved.
    if not args.skip_train:
        print("Starting training")
        # -- set training arguments
        # Checkpointing and evaluation both run every 1000 steps, and the best
        # model is selected by lowest evaluation loss.
        # NOTE(review): reporting goes to tensorboard while CLEARML_* variables
        # are exported above -- confirm ClearML is expected to pick this up.
        training_args = Seq2SeqTrainingArguments(
            output_dir=args.outputdir,
            learning_rate=args.lr,
            per_device_train_batch_size=args.train_bsz,
            per_device_eval_batch_size=args.eval_bsz,
            gradient_accumulation_steps=args.grad_acc_steps,
            weight_decay=args.weight_decay,
            save_strategy="steps",
            save_steps=1000,
            save_total_limit=10,
            num_train_epochs=args.nepochs,
            # Disabled: --maxsteps is parsed but currently has no effect here.
            # max_steps=args.maxsteps,
            warmup_steps=args.warmup_steps,
            lr_scheduler_type="cosine",
            report_to="tensorboard",
            logging_strategy="steps",
            logging_steps=100,
            do_train=True,
            do_eval=True,
            load_best_model_at_end=True,
            predict_with_generate=False,
            metric_for_best_model="loss",
            greater_is_better=False,
            evaluation_strategy="steps",
            eval_steps=1000,
            seed=seed,
            bf16=args.bf16,
            fp16=args.fp16,
            deepspeed=ds_config,
            local_rank=local_rank,
            gradient_checkpointing=False,
            log_on_each_node=False,
            # Presumably needed so the raw "inputs"/"outputs" columns are still
            # present when the tokenize transform runs -- confirm.
            remove_unused_columns=False,
        )

        trainer = Seq2SeqTrainer(
            model=model,
            tokenizer=tokenizer,
            args=training_args,
            train_dataset=ds_train,
            eval_dataset=ds_val,
            data_collator=data_collator,
        )

        # Resumes from `last_checkpoint` when one was found (None starts fresh).
        trainer.train(resume_from_checkpoint=last_checkpoint)
        # Save the final model under <outputdir>/checkpoint-best; with
        # load_best_model_at_end=True this is intended to be the best-loss model.
        trainer.save_model(output_dir=os.path.join(args.outputdir, "checkpoint-best"))
